import os

import torch
import random
from tqdm import tqdm


import torch.nn as nn
from torch import optim as optim
from matplotlib import pyplot as plt
import torch.nn.functional as F

from ModularUtils.FunctionsConstant import asKey
from ModularUtils.ControllerConstants import get_multiple_labels_fill
from ModularUtils.DigitImageGeneration.mnist_image_generation import plot_trained_digits


# Creating a DeepAutoencoder class
# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial9/AE_CIFAR10.html
class DeepAutoencoder(torch.nn.Module):
    """Label-conditioned convolutional autoencoder.

    Adapted from
    https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial9/AE_CIFAR10.html

    The conditioning labels are broadcast to constant feature maps and
    concatenated to the image channels before encoding. They are also
    concatenated to the latent code before decoding.

    Args:
        Exp: experiment config. ``Exp.IMAGE_SIZE`` sets both the expected
            spatial size and the base channel width; ``Exp.DEVICE`` is read
            in ``forward``.
        par_dim: number of label columns used as conditioning.
        latent_dim: size of the latent code produced by the encoder.
    """

    def __init__(self, Exp, par_dim, latent_dim):
        super().__init__()

        label_dim = par_dim
        num_input_channels = 3
        c_hid = Exp.IMAGE_SIZE  # base channel width is tied to the image size
        act_fn = nn.GELU

        # Spatial size after the three stride-2 convs (k=3, p=1), each of
        # which maps n -> ceil(n / 2); e.g. 32 -> 16 -> 8 -> 4.
        feat = Exp.IMAGE_SIZE
        for _ in range(3):
            feat = (feat + 1) // 2
        self._feat_size = feat
        flat_dim = 2 * c_hid * feat * feat  # == 2*16*c_hid for 32x32 images

        # NOTE: module order/types must stay fixed so saved state_dict keys
        # (encoder.0, encoder.11, ...) keep matching existing checkpoints.
        self.encoder = nn.Sequential(
            nn.Conv2d(num_input_channels + label_dim, c_hid, kernel_size=3, padding=1, stride=2),  # n => n/2
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),

            nn.Conv2d(c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # n/2 => n/4
            act_fn(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            act_fn(),

            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1, stride=2),  # n/4 => n/8
            act_fn(),
            nn.Flatten(),  # image grid to a single feature vector
            nn.Linear(flat_dim, latent_dim),
        )

        # Each ConvTranspose2d (k=3, p=1, s=2, output_padding=1) doubles the
        # spatial size, so the output is feat*8 (== IMAGE_SIZE when it is a
        # multiple of 8).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2 * c_hid, 2 * c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),
            act_fn(),
            nn.Conv2d(2 * c_hid, 2 * c_hid, kernel_size=3, padding=1),
            act_fn(),

            nn.ConvTranspose2d(2 * c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),

            nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.Tanh(),  # inputs are scaled to [-1, 1], so the output is bounded the same way
        )

        # Maps [latent code, labels] back to the flattened decoder input.
        self.linear = nn.Sequential(
            nn.Linear(latent_dim + label_dim, flat_dim),
            act_fn(),
        )

    def forward(self, Exp, x, true_y, dim_list, isLatent=True):
        """Encode ``x`` conditioned on ``true_y``; optionally decode it back.

        Args:
            Exp: experiment config; ``Exp.DEVICE`` is where labels are moved.
            x: image batch with 3 channels (presumably (batch, 3, H, W),
                scaled to [-1, 1] -- confirm against the data pipeline).
            true_y: 2-D label tensor of shape (batch, par_dim).
            dim_list: unused here; kept for interface compatibility.
            isLatent: if True return the latent code, otherwise return the
                Tanh-bounded reconstruction.

        Returns:
            Latent tensor (batch, latent_dim) or reconstruction
            (batch, 3, H', W').
        """
        # Labels often arrive on CPU straight from a DataLoader; both the
        # image-side concat and the latent-side concat need them on-device.
        true_y = true_y.to(Exp.DEVICE)

        # (batch, label_dim) -> (batch, label_dim, H, W) constant label maps.
        y = true_y.unsqueeze(2).unsqueeze(3).repeat(1, 1, x.shape[2], x.shape[3])

        z = self.encoder(torch.cat([x, y], 1))
        if isLatent:
            return z

        z = self.linear(torch.cat([z, true_y], 1))
        z = z.reshape(z.shape[0], -1, self._feat_size, self._feat_size)
        return self.decoder(z)





def _save_generator_checkpoint(Exp, label_generators, optimizers):
    """Save every generator's and optimizer's state dict under SAVED_PATH."""
    print(Exp.curr_epoochs, ":Encoder model saved at ", Exp.SAVED_PATH)
    print("=> Saving checkpoint")
    gen_checkpoint = {"epoch": Exp.curr_epoochs,
                      "trained": [lb for lb, isLoad in Exp.load_which_models.items() if isLoad == True]}
    for label in label_generators:
        gen_checkpoint[label + "state_dict"] = label_generators[label].state_dict()
        gen_checkpoint["optimizer" + label] = optimizers[label].state_dict()

    # NOTE(review): the file name uses Exp.curr_epoochs, which is not advanced
    # inside train_encoders, so repeated saves overwrite the same file --
    # confirm whether the local epoch was intended.
    ckpt_dir = Exp.SAVED_PATH + "/checkpoints_generators"
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(gen_checkpoint, ckpt_dir + f"/epoch{Exp.curr_epoochs:03}.pth")
    torch.save(gen_checkpoint, ckpt_dir + "/epochLast.pth")


def _plot_reconstruction_sample(Exp, epoch, img, out):
    """Plot one randomly chosen real/reconstructed pair from a batch."""
    # randrange excludes the upper bound; randint(0, n) could return n and
    # index past the end of the batch.
    rind = random.randrange(img.shape[0])
    plot_dir = Exp.SAVED_PATH + "/VAE_plots"

    real = img[rind].permute(1, 2, 0).detach().cpu().numpy()
    plot_trained_digits(1, 1, [real], f'Epoch-{epoch}-Real', plot_dir)

    fake = out[rind].permute(1, 2, 0).detach().cpu().numpy()
    plot_trained_digits(1, 1, [fake], f'Epoch-{epoch}-fake', plot_dir)


def train_encoders(Exp, rep_mech, label_generators, optimizers, image_data_dict, label_data, num_epochs=200):
    """Train ``label_generators[rep_mech]`` as a conditional autoencoder.

    The model reconstructs observational images (pixel MSE loss) while being
    conditioned on the non-image parents of ``rep_mech`` in
    ``Exp.Observed_DAG``. Every 5 epochs a sample reconstruction is plotted
    and all generators/optimizers are checkpointed; a loss curve is shown at
    the end.

    Args:
        Exp: experiment config (batch_size, DEVICE, SAVED_PATH, DAG/label info).
        rep_mech: key of the model/optimizer to train.
        label_generators: dict of models; all of them are checkpointed.
        optimizers: dict of optimizers keyed like ``label_generators``.
        image_data_dict: image datasets keyed by ``asKey(...)``.
        label_data: label arrays keyed by ``asKey(...)`` then ``"obs"``.
        num_epochs: number of training epochs (default 200).

    Raises:
        ValueError: if the observational image dataset is empty.
    """
    criterion = torch.nn.MSELoss()
    model = label_generators[rep_mech]
    optimizer = optimizers[rep_mech]

    image_data_loader = torch.utils.data.DataLoader(dataset=image_data_dict[asKey({})],
                                                    batch_size=Exp.batch_size,
                                                    shuffle=False)

    # NOTE(review): set difference yields a hash-dependent parent order, which
    # fixes the label column order the model is trained on -- confirm every
    # caller/loader derives parents the same way.
    parents = list(set(Exp.Observed_DAG[rep_mech]) - set(Exp.image_labels))
    pid = [Exp.label_names.index(par) for par in parents]
    dim_list = [Exp.label_dim[par] for par in parents]

    # shuffle=False on both loaders keeps image batch i aligned with label batch i.
    label_loader = torch.utils.data.DataLoader(dataset=label_data[asKey({})]["obs"][:, pid],
                                               batch_size=Exp.batch_size,
                                               shuffle=False)

    num_batches = len(image_data_loader)
    if num_batches == 0:
        raise ValueError("train_encoders: the observational image dataset is empty")

    train_loss = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        label_iterator = iter(label_loader)

        for batch in tqdm(image_data_loader):
            # Datasets yielding (image, ...) tuples collate to a list/tuple.
            img = batch[0] if isinstance(batch, (list, tuple)) else batch
            img = img.to(Exp.DEVICE)
            labels = next(label_iterator).to(Exp.DEVICE)

            out = model(Exp, img, labels, dim_list, isLatent=False)
            loss = criterion(out, img)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Mean per-batch loss for this epoch.
        running_loss /= num_batches
        train_loss.append(round(running_loss, 4))
        print("epoch:", epoch, train_loss[-10:])

        if epoch % 5 == 0:
            _plot_reconstruction_sample(Exp, epoch, img, out)
            _save_generator_checkpoint(Exp, label_generators, optimizers)

    # Plotting the training loss
    plt.plot(range(1, num_epochs + 1), train_loss)
    plt.xlabel("Number of epochs")
    plt.ylabel("Training Loss")
    plt.show()

